Fix KV head shape mismatch when TP size exceeds num_kv_heads #3426

Merged: copybara-service[bot] merged 1 commit into main from mohit/ep_replicate on Apr 2, 2026

@khatwanimohit (Collaborator) commented Mar 16, 2026

Description

This PR was done in collaboration with @NicoGrande

Problem: When serving models like Qwen3-30B-A3B (4 KV heads) with TP=8, adapter.py pads base_num_kv_heads to 8 to match TP size. This caused Orbax to reject checkpoint restore because the stored shape (seq, 4, 128) didn't match the model's padded shape (seq, 8, 128).
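The mismatch described above can be reproduced with plain shape arithmetic. The following is a minimal sketch, not the code in adapter.py: `padded_kv_heads` and `shapes_match` are hypothetical helper names, and `SEQ` is a placeholder value because the description leaves the sequence dimension symbolic.

```python
def padded_kv_heads(num_kv_heads: int, tp_size: int) -> int:
    """Hypothetical helper: KV head count after padding up to the TP size."""
    return max(num_kv_heads, tp_size)


def shapes_match(ckpt_shape, model_shape) -> bool:
    """True when the stored checkpoint shape equals the shape the model expects."""
    return tuple(ckpt_shape) == tuple(model_shape)


SEQ = 16        # placeholder; the PR description only writes "seq"
HEAD_DIM = 128  # head dimension from the shapes quoted in the description

# Qwen3-30B-A3B stores 4 KV heads; with TP=8 the model side is padded to 8.
ckpt_shape = (SEQ, 4, HEAD_DIM)
model_shape = (SEQ, padded_kv_heads(4, 8), HEAD_DIM)

# A strict shape check during restore fails on exactly this pair.
mismatch = not shapes_match(ckpt_shape, model_shape)
```

When the TP size does not exceed the KV head count, `padded_kv_heads` returns the original count and the two shapes agree, so nothing needs to change on that path.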

Fix:

  • adapter.py: Pad base_num_kv_heads to max_tp_size when num_kv_heads < TP_size, using ShardingAxisName.ATTN_HEAD to determine the relevant mesh axis size.
  • model_creation_utils.py: Add two helpers wired into create_nnx_model:
      • _fix_restore_args_for_shape_mismatch: Before restore, detects arrays whose checkpoint shape differs from the model shape (via Orbax metadata) and switches them to fully-replicated sharding with no global_shape, allowing Orbax to load them in their stored shape.
      • _expand_checkpoint_to_model_shapes: After restore, uses jnp.repeat (not tile) to expand mismatching arrays to the model shape, then re-shards via jax.device_put. repeat is correct for GQA: device i needs KV head i//ratio, producing [h0,h0,h1,h1,...] rather than [h0,h1,...,h0,h1,...].
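The detect-then-expand flow and the repeat-versus-tile ordering can be sketched on a tiny array. This is an illustration under stated assumptions rather than the PR's implementation: `find_shape_mismatches` and `expand_kv_heads` are hypothetical names, plain dicts stand in for the Orbax metadata, and NumPy is used so the snippet runs without an accelerator (`jnp.repeat` and `jnp.tile` follow the same semantics as their NumPy counterparts). The re-sharding step via jax.device_put is not shown.

```python
import numpy as np


def find_shape_mismatches(ckpt_shapes, model_shapes):
    """Return the names whose stored shape differs from the model shape.

    Both arguments are plain dicts in this sketch; the PR reads the stored
    shapes from Orbax checkpoint metadata, which is not modeled here.
    """
    return sorted(
        name
        for name, shape in model_shapes.items()
        if tuple(ckpt_shapes[name]) != tuple(shape)
    )


def expand_kv_heads(kv, target_heads, axis=1):
    """Replicate KV heads along `axis` until there are `target_heads` of them."""
    stored_heads = kv.shape[axis]
    if stored_heads == target_heads:
        return kv  # nothing to expand; leave the array untouched
    if target_heads % stored_heads != 0:
        raise ValueError("target_heads must be a multiple of the stored head count")
    ratio = target_heads // stored_heads
    # repeat duplicates each head in place: h0,h0,h1,h1,...
    return np.repeat(kv, ratio, axis=axis)


# Tiny stand-in for a (seq, 4, head_dim) array: head h is filled with the value h.
seq, heads, head_dim = 2, 4, 3
kv = np.broadcast_to(np.arange(heads)[None, :, None], (seq, heads, head_dim)).copy()

# Illustrative names only; "kv" and "other" are not real parameter paths.
mismatched = find_shape_mismatches(
    {"kv": kv.shape, "other": (2, 3)},
    {"kv": (seq, 8, head_dim), "other": (2, 3)},
)

expanded = expand_kv_heads(kv, 8)
repeat_order = expanded[0, :, 0].tolist()
tile_order = np.tile(kv, (1, 2, 1))[0, :, 0].tolist()

# With ratio 2, device i should hold KV head i // 2.
device_to_head = [i // 2 for i in range(8)]
```

Here `repeat_order` matches `device_to_head`, while `tile_order` cycles through all four heads twice, which would hand devices 1 through 6 the wrong head.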

Tests

CI Tests

Ran vllm_decode for Qwen3-30B-A3B (4 KV heads) on a v6e-8.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NicoGrande (Collaborator) left a comment

Could you test on a VM with 4 JAX devices (maybe a v5p-8)? I tried on my v6e-4 and saw nonsense outputs, so perhaps this is currently breaking the code path where no changes are necessary.

Comment thread src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py (outdated)

codecov Bot commented Mar 16, 2026

Codecov Report

❌ Patch coverage is 40.27778% with 43 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/utils/model_creation_utils.py | 44.61% | 32 Missing and 4 partials ⚠️ |
| src/maxtext/trainers/post_train/rl/train_rl.py | 0.00% | 5 Missing ⚠️ |
| src/maxtext/layers/moe.py | 0.00% | 1 Missing and 1 partial ⚠️ |

@NicoGrande (Collaborator) left a comment

LGTM!

@xuefgu (Collaborator) left a comment

Which Orbax version is this PR tested against?

At what scale is this PR tested? Any practical runs on GKE using Pathways?

@khatwanimohit force-pushed the mohit/ep_replicate branch 4 times, most recently from 04b6506 to 1d2eefa (March 18, 2026 17:29)
@A9isha force-pushed the mohit/ep_replicate branch from aebda8a to a2b9861 (March 24, 2026 22:22)
@khatwanimohit force-pushed the mohit/ep_replicate branch 3 times, most recently from 3dbfa18 to b1abcf7 (April 2, 2026 18:30)

@khatwanimohit (Collaborator, Author) commented

> Which Orbax version is this PR tested against?
>
> At what scale is this PR tested? Any practical runs on GKE using Pathways?

A run for Qwen3-30B-A3B on 1 slice of v5p-128 https://cloudlogging.app.goo.gl/Rna8hP21D5XKAHoF9

More runs are in this bug: b/498435735

@xuefgu (Collaborator) left a comment

For posterity:
(1) Fast follow - logical axis rules
(2) Currently it works for v5p but not for v7x yet, although we are not sure whether that's because of our stack. See b/496608632.

@copybara-service copybara-service Bot merged commit 5905690 into main Apr 2, 2026
36 of 37 checks passed
@copybara-service copybara-service Bot deleted the mohit/ep_replicate branch April 2, 2026 22:25